//
// Copyright 2018 Ettus Research, A National Instruments Company
//
// SPDX-License-Identifier: LGPL-3.0-or-later
//
// Module: mesh_2d_dor_router_single_sw
// Description: 
//   This module implements the router for a 2-dimensional (2d)
//   mesh network that uses dimension order routing (dor) and has a 
//   single underlying switch (single_sw). It uses AXI-Stream for all of its
//   links. 
//   The mesh topology, routing algorithms and router architecture are
//   described in README.md in this directory.
// Parameters:
//   - WIDTH: Width of the AXI-Stream data bus
//   - DIM_SIZE: Number of routers along one dimension
//   - XB_ADDR_X: The X-coordinate of this router in the topology
//   - XB_ADDR_Y: The Y-coordinate of this router in the topology
//   - TERM_BUFF_SIZE: log2 of the ingress terminal buffer size (in words)
//   - XB_BUFF_SIZE: log2 of the ingress inter-router buffer size (in words)
//   - ROUTING_ALLOC: Algorithm to allocate routing paths between routers.
//     * WORMHOLE: Allocate route as soon as first word in pkt arrives
//     * CUT-THROUGH: Allocate route only after the full pkt arrives
//   - SWITCH_ALLOC: Algorithm to allocate the switch
//     * PRIO: Fixed priority. Priority: South > East > North > West > Term
//     * ROUND-ROBIN: Round robin input port allocation
// Signals:
//   - *_axis_ter_*: Terminal ports (master/slave)
//   - *_axis_wst_*: Inter-router X-dim west connections (master/slave)
//   - *_axis_est_*: Inter-router X-dim east connections (master/slave)
//   - *_axis_nor_*: Inter-router Y-dim north connections (master/slave)
//   - *_axis_sou_*: Inter-router Y-dim south connections (master/slave)
//

module mesh_2d_dor_router_single_sw #(
  parameter                        WIDTH          = 64,
  parameter                        DIM_SIZE       = 4,
  parameter [$clog2(DIM_SIZE)-1:0] XB_ADDR_X      = 0,
  parameter [$clog2(DIM_SIZE)-1:0] XB_ADDR_Y      = 0,
  parameter                        TERM_BUFF_SIZE = 5,
  parameter                        XB_BUFF_SIZE   = 5,
  parameter                        ROUTING_ALLOC  = "WORMHOLE",  // Routing (switching) method {WORMHOLE, CUT-THROUGH}
  parameter                        SWITCH_ALLOC   = "PRIO"       // Switch allocation algorithm {ROUND-ROBIN, PRIO}
) (
  // Clocks and resets
  input  wire             clk,
  input  wire             reset,

  // Terminal connections
  input  wire [WIDTH-1:0] s_axis_ter_tdata,
  input  wire             s_axis_ter_tlast,
  input  wire             s_axis_ter_tvalid,
  output wire             s_axis_ter_tready,
  output wire [WIDTH-1:0] m_axis_ter_tdata,
  output wire             m_axis_ter_tlast,
  output wire             m_axis_ter_tvalid,
  input  wire             m_axis_ter_tready,

  // West inter-router connections
  input  wire [WIDTH-1:0] s_axis_wst_tdata,
  input  wire [0:0]       s_axis_wst_tdest,
  input  wire             s_axis_wst_tlast,
  input  wire             s_axis_wst_tvalid,
  output wire             s_axis_wst_tready,
  output wire [WIDTH-1:0] m_axis_wst_tdata,
  output wire [0:0]       m_axis_wst_tdest,
  output wire             m_axis_wst_tlast,
  output wire             m_axis_wst_tvalid,
  input  wire             m_axis_wst_tready,

  // East inter-router connections
  input  wire [WIDTH-1:0] s_axis_est_tdata,
  input  wire [0:0]       s_axis_est_tdest,
  input  wire             s_axis_est_tlast,
  input  wire             s_axis_est_tvalid,
  output wire             s_axis_est_tready,
  output wire [WIDTH-1:0] m_axis_est_tdata,
  output wire [0:0]       m_axis_est_tdest,
  output wire             m_axis_est_tlast,
  output wire             m_axis_est_tvalid,
  input  wire             m_axis_est_tready,

  // North inter-router connections
  input  wire [WIDTH-1:0] s_axis_nor_tdata,
  input  wire [0:0]       s_axis_nor_tdest,
  input  wire             s_axis_nor_tlast,
  input  wire             s_axis_nor_tvalid,
  output wire             s_axis_nor_tready,
  output wire [WIDTH-1:0] m_axis_nor_tdata,
  output wire [0:0]       m_axis_nor_tdest,
  output wire             m_axis_nor_tlast,
  output wire             m_axis_nor_tvalid,
  input  wire             m_axis_nor_tready,

  // South inter-router connections
  input  wire [WIDTH-1:0] s_axis_sou_tdata,
  input  wire [0:0]       s_axis_sou_tdest,
  input  wire             s_axis_sou_tlast,
  input  wire             s_axis_sou_tvalid,
  output wire             s_axis_sou_tready,
  output wire [WIDTH-1:0] m_axis_sou_tdata,
  output wire [0:0]       m_axis_sou_tdest,
  output wire             m_axis_sou_tlast,
  output wire             m_axis_sou_tvalid,
  input  wire             m_axis_sou_tready
);
  // -------------------------------------------------
  // Routing functions
  // -------------------------------------------------

  // mesh_node_mapping.vh file contains the mapping between the node number
  // and its XY coordinates. It is autogenerated and defines the node_to_xdst
  // and node_to_ydst functions.
  `include "mesh_node_mapping.vh"

  // Switch port indices. They are used both as output-port codes (the low
  // 3 bits of a switch tdest) and as input-port indices for the allocator.
  // The bit order matches the {sou, nor, est, wst, ter} concatenations that
  // are passed to axis_switch below.
  localparam [2:0] SW_DEST_TER  = 3'd0;
  localparam [2:0] SW_DEST_WST  = 3'd1;
  localparam [2:0] SW_DEST_EST  = 3'd2;
  localparam [2:0] SW_DEST_NOR  = 3'd3;
  localparam [2:0] SW_DEST_SOU  = 3'd4;
  localparam [2:0] SW_NUM_DESTS = 3'd5;

  // compute_switch_tdest: destination selector for the switch.
  // Derives the destination node's X/Y coordinates from the first word of a
  // packet (through node_to_xdst/node_to_ydst from mesh_node_mapping.vh) and
  // applies dimension-order routing. X is resolved first (west/east), then
  // Y (north/south), then the packet is delivered to the local terminal.
  //   header : first word of the packet
  //   src    : switch input port the packet arrived on (SW_DEST_*)
  // Returns {VC, port}. The MSB is the virtual channel and the 3 LSBs are
  // the switch output port (SW_DEST_*).
  function [3:0] compute_switch_tdest;
    input [WIDTH-1:0] header;
    input [3:0]       src;
    reg [$clog2(DIM_SIZE)-1:0] xdst, ydst;
    reg signed [$clog2(DIM_SIZE):0] xdiff, ydiff;
  begin
    xdst  = node_to_xdst(header);
    ydst  = node_to_ydst(header);
    // The subtraction is evaluated one bit wider than the coordinates (the
    // width of xdiff/ydiff). The stored result is therefore a correct
    // two's-complement signed difference.
    xdiff = xdst - XB_ADDR_X;
    ydiff = ydst - XB_ADDR_Y;
    if (xdiff == 'd0 && ydiff == 'd0) begin
      // Arrived: deliver to the local terminal. VC=0 because terminals
      // don't have VCs.
      compute_switch_tdest = {1'b0, SW_DEST_TER};
    end else if (xdiff == 'd0) begin
      // X is resolved, so route along Y. Only the north/south links carry
      // two VCs. VC=1 is used for packets that enter on WST and leave on NOR,
      // or enter on EST and leave on SOU (the CCW turns). All other
      // north/south traffic uses VC=0.
      if (ydiff < 0)
        compute_switch_tdest = {(src == SW_DEST_WST), SW_DEST_NOR};
      else
        compute_switch_tdest = {(src == SW_DEST_EST), SW_DEST_SOU};
    end else begin
      // Route along X first. VC=0 because east-west paths don't have VCs.
      if (xdiff < 0)
        compute_switch_tdest = {1'b0, SW_DEST_WST};
      else
        compute_switch_tdest = {1'b0, SW_DEST_EST};
    end
    // Simulation-only sanity check: report routes that would leave the
    // mesh. The check is skipped when the coordinates contain X/Z bits (for
    // example, idle input with uninitialized data). A plain `!= 'hx`
    // comparison always yields X and would silently disable the check.
    // Only the port field [2:0] is compared, so the check still applies
    // when the VC bit is set.
    // synthesis translate_off
    if ((^{xdst, ydst}) !== 1'bx) begin
      if (XB_ADDR_X == 0 && compute_switch_tdest[2:0] == SW_DEST_WST)
        $display("Illegal route chosen: WEST. xdst=%d, ydst=%d, xaddr=%d, yaddr=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y);
      if (XB_ADDR_X == DIM_SIZE-1 && compute_switch_tdest[2:0] == SW_DEST_EST)
        $display("Illegal route chosen: EAST. xdst=%d, ydst=%d, xaddr=%d, yaddr=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y);
      if (XB_ADDR_Y == 0 && compute_switch_tdest[2:0] == SW_DEST_NOR)
        $display("Illegal route chosen: NORTH. xdst=%d, ydst=%d, xaddr=%d, yaddr=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y);
      if (XB_ADDR_Y == DIM_SIZE-1 && compute_switch_tdest[2:0] == SW_DEST_SOU)
        $display("Illegal route chosen: SOUTH. xdst=%d, ydst=%d, xaddr=%d, yaddr=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y);
    end
    // synthesis translate_on
    //$display("xdst=%d, ydst=%d, xaddr=%d, yaddr=%d, dst=%d", xdst, ydst, XB_ADDR_X, XB_ADDR_Y, compute_switch_tdest);
  end
  endfunction

  // compute_switch_alloc: switch allocator.
  // Chooses which input port reserves the switch for the next packet. After
  // the switch is allocated, all other ports are backpressured until the
  // packet finishes transferring.
  //   pkt_waiting : one bit per input port, indexed by SW_DEST_*. Each bit
  //                 is set when that port has data waiting.
  //   last_alloc  : input port allocated to the previous packet (used by
  //                 the round-robin policy)
  // Returns the SW_DEST_* index of the selected input port, or SW_DEST_TER
  // when nothing is waiting.
  function [2:0] compute_switch_alloc;
    input [4:0] pkt_waiting;
    input [2:0] last_alloc;
  begin
    if (pkt_waiting == 5'b00000) begin
      compute_switch_alloc = SW_DEST_TER;
    end else if (pkt_waiting == 5'b00001) begin
      compute_switch_alloc = SW_DEST_TER;
    end else if (pkt_waiting == 5'b00010) begin
      compute_switch_alloc = SW_DEST_WST;
    end else if (pkt_waiting == 5'b00100) begin
      compute_switch_alloc = SW_DEST_EST;
    end else if (pkt_waiting == 5'b01000) begin
      compute_switch_alloc = SW_DEST_NOR;
    end else if (pkt_waiting == 5'b10000) begin
      compute_switch_alloc = SW_DEST_SOU;
    end else begin
      if (SWITCH_ALLOC == "PRIO") begin
        // Fixed priority: South > East > North > West > Term
        if (pkt_waiting[SW_DEST_SOU])
          compute_switch_alloc = SW_DEST_SOU;
        else if (pkt_waiting[SW_DEST_EST])
          compute_switch_alloc = SW_DEST_EST;
        else if (pkt_waiting[SW_DEST_NOR])
          compute_switch_alloc = SW_DEST_NOR;
        else if (pkt_waiting[SW_DEST_WST])
          compute_switch_alloc = SW_DEST_WST;
        else
          compute_switch_alloc = SW_DEST_TER;
      end else begin
        // Round-robin: starting with the port after last_alloc, grant the
        // first port that has a packet waiting. The offsets are unsized
        // integers, so the sum is evaluated in 32 bits. With 3-bit operands,
        // last_alloc=4 plus offset 4 would wrap to 0, and the NOR port would
        // never be considered after a SOU grant.
        if (pkt_waiting[(last_alloc + 1) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 1) % SW_NUM_DESTS;
        else if (pkt_waiting[(last_alloc + 2) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 2) % SW_NUM_DESTS;
        else if (pkt_waiting[(last_alloc + 3) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 3) % SW_NUM_DESTS;
        else if (pkt_waiting[(last_alloc + 4) % SW_NUM_DESTS])
          compute_switch_alloc = (last_alloc + 4) % SW_NUM_DESTS;
        else
          compute_switch_alloc = last_alloc;
      end
    end
    //$display("pkt_waiting=%b, alloc=%d, last_alloc=%d", pkt_waiting, compute_switch_alloc, last_alloc);
  end
  endfunction

  // -------------------------------------------------
  // Input buffers
  // -------------------------------------------------
  wire [WIDTH-1:0] ter_i_tdata;
  wire [3:0]       ter_i_tdest;
  wire             ter_i_tlast;
  wire             ter_i_tvalid;
  wire             ter_i_tready;

  // Data coming in from the terminal is gated until a full packet arrives
  // in order to minimize the switch allocation time per packet.
  axi_packet_gate #(
    .WIDTH(WIDTH), .SIZE(TERM_BUFF_SIZE)
  ) term_in_pkt_gate_i (
    .clk      (clk), 
    .reset    (reset), 
    .clear    (1'b0),
    .i_tdata  (s_axis_ter_tdata),
    .i_tlast  (s_axis_ter_tlast),
    .i_tvalid (s_axis_ter_tvalid),
    .i_tready (s_axis_ter_tready),
    .i_terror (1'b0),
    .o_tdata  (ter_i_tdata),
    .o_tlast  (ter_i_tlast),
    .o_tvalid (ter_i_tvalid),
    .o_tready (ter_i_tready)
  );
  assign ter_i_tdest = compute_switch_tdest(ter_i_tdata, SW_DEST_TER);

  wire [WIDTH-1:0] wst_i_tdata,  est_i_tdata,  nor_i_tdata,  sou_i_tdata;
  wire [3:0]       wst_i_tdest,  est_i_tdest,  nor_i_tdest,  sou_i_tdest;
  wire             wst_i_tlast,  est_i_tlast,  nor_i_tlast,  sou_i_tlast;
  wire             wst_i_tvalid, est_i_tvalid, nor_i_tvalid, sou_i_tvalid;
  wire             wst_i_tready, est_i_tready, nor_i_tready, sou_i_tready;

  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(1),
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) wst_in_vc_buf_i (
    .clk           (clk), 
    .reset         (reset), 
    .s_axis_tdata  (s_axis_wst_tdata),
    .s_axis_tdest  (s_axis_wst_tdest),
    .s_axis_tlast  (s_axis_wst_tlast),
    .s_axis_tvalid (s_axis_wst_tvalid),
    .s_axis_tready (s_axis_wst_tready),
    .m_axis_tdata  (wst_i_tdata),
    .m_axis_tlast  (wst_i_tlast),
    .m_axis_tvalid (wst_i_tvalid),
    .m_axis_tready (wst_i_tready)
  );
  assign wst_i_tdest = compute_switch_tdest(wst_i_tdata, SW_DEST_WST);

  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(1),
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) est_in_vc_buf_i (
    .clk           (clk), 
    .reset         (reset), 
    .s_axis_tdata  (s_axis_est_tdata),
    .s_axis_tdest  (s_axis_est_tdest),
    .s_axis_tlast  (s_axis_est_tlast),
    .s_axis_tvalid (s_axis_est_tvalid),
    .s_axis_tready (s_axis_est_tready),
    .m_axis_tdata  (est_i_tdata),
    .m_axis_tlast  (est_i_tlast),
    .m_axis_tvalid (est_i_tvalid),
    .m_axis_tready (est_i_tready)
  );
  assign est_i_tdest = compute_switch_tdest(est_i_tdata, SW_DEST_EST);

  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(2), // Only north-south traffic has VCs
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) nor_in_vc_buf_i (
    .clk           (clk), 
    .reset         (reset), 
    .s_axis_tdata  (s_axis_nor_tdata),
    .s_axis_tdest  (s_axis_nor_tdest),
    .s_axis_tlast  (s_axis_nor_tlast),
    .s_axis_tvalid (s_axis_nor_tvalid),
    .s_axis_tready (s_axis_nor_tready),
    .m_axis_tdata  (nor_i_tdata),
    .m_axis_tlast  (nor_i_tlast),
    .m_axis_tvalid (nor_i_tvalid),
    .m_axis_tready (nor_i_tready)
  );
  assign nor_i_tdest = compute_switch_tdest(nor_i_tdata, SW_DEST_NOR);

  axis_ingress_vc_buff #(
    .WIDTH(WIDTH), .NUM_VCS(2), // Only north-south traffic has VCs
    .SIZE(XB_BUFF_SIZE),
    .ROUTING(ROUTING_ALLOC)
  ) sou_in_vc_buf_i (
    .clk           (clk), 
    .reset         (reset), 
    .s_axis_tdata  (s_axis_sou_tdata),
    .s_axis_tdest  (s_axis_sou_tdest),
    .s_axis_tlast  (s_axis_sou_tlast),
    .s_axis_tvalid (s_axis_sou_tvalid),
    .s_axis_tready (s_axis_sou_tready),
    .m_axis_tdata  (sou_i_tdata),
    .m_axis_tlast  (sou_i_tlast),
    .m_axis_tvalid (sou_i_tvalid),
    .m_axis_tready (sou_i_tready)
  );
  assign sou_i_tdest = compute_switch_tdest(sou_i_tdata, SW_DEST_SOU);

  //-------------------------------------------------
  // Switch
  //-------------------------------------------------
  // Track whether the next word entering the switch is a packet head
  localparam [0:0] PKT_ST_HEAD = 1'b0;
  localparam [0:0] PKT_ST_BODY = 1'b1;
  reg [0:0] pkt_state = PKT_ST_HEAD;

  // Per-input-port handshake signals, bit-indexed by SW_DEST_*
  wire [4:0] sw_in_tvalid = {sou_i_tvalid, nor_i_tvalid, est_i_tvalid, wst_i_tvalid, ter_i_tvalid};
  wire [4:0] sw_in_tready = {sou_i_tready, nor_i_tready, est_i_tready, wst_i_tready, ter_i_tready};
  wire [4:0] sw_in_tlast  = {sou_i_tlast,  nor_i_tlast,  est_i_tlast,  wst_i_tlast,  ter_i_tlast };

  // The switch only accepts data on a single (allocated) port at a time.
  // A word transfers only on a port where tvalid and tready are both high.
  // Under AXI-Stream, a port's tready may be high while its tvalid is low.
  // If the transfer were computed as (OR of all tvalid) & (OR of all tready),
  // an idle allocated port plus data waiting on another port would register
  // a bogus transfer, possibly with a bogus tlast that ends the packet early.
  wire [4:0] sw_in_beat  = sw_in_tvalid & sw_in_tready;
  wire       sw_in_valid = |sw_in_tvalid;
  wire       sw_in_xfer  = |sw_in_beat;
  wire       sw_in_last  = |(sw_in_beat & sw_in_tlast);

  always @(posedge clk) begin
    if (reset) begin
      pkt_state <= PKT_ST_HEAD;
    end else if (sw_in_xfer) begin
      pkt_state <= sw_in_last ? PKT_ST_HEAD : PKT_ST_BODY;
    end
  end

  // The switch requires the allocation to stay valid until the end of the
  // packet. The previous packet's allocation is also kept, because the
  // round-robin allocator needs it to compute the next one.
  wire [2:0] switch_alloc;
  reg  [2:0] prev_switch_alloc = SW_DEST_TER;
  reg  [2:0] pkt_switch_alloc  = SW_DEST_TER;

  always @(posedge clk) begin
    if (reset) begin
      prev_switch_alloc <= SW_DEST_TER;
      pkt_switch_alloc  <= SW_DEST_TER;
    end else if (sw_in_xfer) begin
      // Latch the allocation on the head beat and hold it for the body
      if (pkt_state == PKT_ST_HEAD)
        pkt_switch_alloc <= switch_alloc;
      // Remember the port that completed a packet (round-robin history)
      if (sw_in_last)
        prev_switch_alloc <= switch_alloc;
    end
  end

  // At a packet boundary with data waiting, allocate a new input port.
  // Otherwise hold the allocation made for the current packet.
  assign switch_alloc = (sw_in_valid && pkt_state == PKT_ST_HEAD) ?
    compute_switch_alloc(sw_in_tvalid, prev_switch_alloc) :
    pkt_switch_alloc;

  // The terminal output has no tdest, so its VC bit is dropped
  wire ter_tdest_discard;
  axis_switch #(
    .DATA_W(WIDTH), .DEST_W(1), .IN_PORTS(5), .OUT_PORTS(5)
  ) switch_i (
    .clk           (clk),
    .reset         (reset),
    .s_axis_tdata  ({sou_i_tdata , nor_i_tdata , est_i_tdata , wst_i_tdata , ter_i_tdata }),
    .s_axis_tdest  ({sou_i_tdest , nor_i_tdest , est_i_tdest , wst_i_tdest , ter_i_tdest }), 
    .s_axis_tlast  ({sou_i_tlast , nor_i_tlast , est_i_tlast , wst_i_tlast , ter_i_tlast }),
    .s_axis_tvalid ({sou_i_tvalid, nor_i_tvalid, est_i_tvalid, wst_i_tvalid, ter_i_tvalid}),
    .s_axis_tready ({sou_i_tready, nor_i_tready, est_i_tready, wst_i_tready, ter_i_tready}),
    .s_axis_alloc  (switch_alloc),
    .m_axis_tdata  ({m_axis_sou_tdata,  m_axis_nor_tdata,  m_axis_est_tdata,  m_axis_wst_tdata,  m_axis_ter_tdata }),
    .m_axis_tdest  ({m_axis_sou_tdest,  m_axis_nor_tdest,  m_axis_est_tdest,  m_axis_wst_tdest,  ter_tdest_discard}),
    .m_axis_tlast  ({m_axis_sou_tlast,  m_axis_nor_tlast,  m_axis_est_tlast,  m_axis_wst_tlast,  m_axis_ter_tlast }),
    .m_axis_tvalid ({m_axis_sou_tvalid, m_axis_nor_tvalid, m_axis_est_tvalid, m_axis_wst_tvalid, m_axis_ter_tvalid}),
    .m_axis_tready ({m_axis_sou_tready, m_axis_nor_tready, m_axis_est_tready, m_axis_wst_tready, m_axis_ter_tready})
  );


endmodule

